4 Main Results
Interpret and summarize the prediction and stability results.
Evaluate pipeline on test data.
Summarize test set prediction and/or interpretability results.
caret
# Cross-validation setup for caret: 5-fold CV, sequential execution.
# classProbs only makes sense for classification, i.e. a factor response;
# the redundant `if (...) TRUE else FALSE` is replaced by the condition itself.
trcontrol <- caret::trainControl(
  method = "cv",
  number = 5,
  classProbs = is.factor(ytrain),
  summaryFunction = caret::defaultSummary,
  allowParallel = FALSE,
  verboseIter = FALSE
)
# Prediction type forwarded to predict() below: "raw" yields class labels
# (classification) or numeric predictions (regression), not probabilities.
response <- "raw"
# Candidate caret models and their tuning grids / engine arguments.
# List elements other than tuneGrid (importance, num.threads, nthread)
# are forwarded by caret::train() to the underlying engine.
model_list <- list(
  ranger = list(
    # mtry is a count of candidate variables per split and must be an
    # integer; the raw seq() from sqrt(p) (classification default) to
    # p / 3 (regression default) is generally fractional, so round it
    # and drop duplicates.
    tuneGrid = expand.grid(
      mtry = unique(round(seq(sqrt(ncol(Xtrain)),
                              ncol(Xtrain) / 3,
                              length.out = 3))),
      splitrule = "gini",
      min.node.size = 1
    ),
    importance = "impurity",
    num.threads = 1
  ),
  xgbTree = list(
    tuneGrid = expand.grid(
      nrounds = c(10, 25, 50, 100, 150),
      max_depth = c(3, 6),
      colsample_bytree = 0.33,
      eta = c(0.1, 0.3),
      gamma = 0,
      min_child_weight = 1,
      subsample = 0.6
    ),
    nthread = 1
  )
)
# Per-model result containers, filled by name inside the loop.
model_fits <- list()
model_preds <- list()
model_errs <- list()
model_vimps <- list()
# Train each candidate with caret, then score it on the held-out test set.
for (model_name in names(model_list)) {
  extra_args <- model_list[[model_name]]
  # caret::train() wants NULL rather than an empty list when there are no
  # model-specific arguments to forward.
  if (identical(extra_args, list())) {
    extra_args <- NULL
  }
  train_args <- c(
    list(x = as.data.frame(Xtrain),
         y = ytrain,
         trControl = trcontrol,
         method = model_name),
    extra_args
  )
  fit <- do.call(caret::train, args = train_args)
  model_fits[[model_name]] <- fit
  model_preds[[model_name]] <- predict(fit, as.data.frame(Xtest),
                                       type = response)
  model_errs[[model_name]] <- caret::postResample(
    pred = model_preds[[model_name]], obs = ytest
  )
  model_vimps[[model_name]] <- caret::varImp(fit)
}
# Combine the per-model results into long tibbles keyed by model name.
model_preds <- dplyr::bind_rows(model_preds, .id = "model")
model_errs <- dplyr::bind_rows(model_errs, .id = "model")
# varImp() objects carry the scores in $importance with variables as
# rownames; lift the rownames into a column before stacking.
model_vimps <- purrr::map_dfr(model_vimps,
                              ~.x[["importance"]] %>%
                                tibble::rownames_to_column("variable"),
                              .id = "model")

tidymodels
# TODO: add code for tuning parameters
# recipes::recipe() needs a data frame, but `splits` is an rsample split
# object (it is handed to tune::last_fit() later in this script), so build
# the recipe on its training portion.
mod_recipe <- recipes::recipe(.y ~ ., data = rsample::training(splits))
# Random forest (classification): tune mtry over a small grid and ask
# ranger for impurity-based variable importance.
rf_model <- parsnip::rand_forest(mtry = tune::tune()) %>%
  parsnip::set_engine("ranger", importance = "impurity") %>%
  parsnip::set_mode("classification")
rf_grid <- tidyr::crossing(mtry = 1:4)
# RBF-kernel SVM (classification) with kernlab defaults; no tuning grid.
svm_model <- parsnip::svm_rbf() %>%
  parsnip::set_mode("classification") %>%
  parsnip::set_engine("kernlab")
# k-nearest neighbors (classification): tune both the neighborhood size
# and the neighbor-weighting kernel.
# tune() must be namespace-qualified: this script never attaches the tune
# package with library(), so a bare tune() would fail at evaluation time
# (the RF spec above already uses tune::tune()).
knn_model <- parsnip::nearest_neighbor() %>%
  parsnip::set_args(neighbors = tune::tune(), weight_func = tune::tune()) %>%
  parsnip::set_engine("kknn") %>%
  parsnip::set_mode("classification")
# models <- workflowsets::workflow_set(
# preproc = list(Base = mod_recipe),
# models = list(RF = rf_model, SVM = svm_model, KNN = knn_model),
# cross = TRUE
# ) %>%
# workflowsets::option_add(grid = rf_grid, id = "Base_RF")
# model_fits <- workflowsets::workflow_map(
# object = models,
# fn = "tune_grid"
# )
# Candidate tidymodels specs paired with their tuning grids:
#   grid = <data frame> -> tune over exactly those candidates;
#   grid = <integer>    -> tune_grid() generates that many candidates;
#   grid = NULL         -> no tuning, fit the workflow directly.
model_list <- list(
  RF = list(model = rf_model, grid = rf_grid),
  SVM = list(model = svm_model, grid = NULL),
  KNN = list(model = knn_model, grid = 4)
)
# Per-model result containers, filled by name inside the loop.
model_fits <- list()
model_preds <- list()
model_errs <- list()
model_vimps <- list()
# Fit, tune, and evaluate each tidymodels spec on the train/test split.
for (model_name in names(model_list)) {
mod <- model_list[[model_name]]$model
grid <- model_list[[model_name]]$grid
if (!is.null(grid)) {
# Tuning requested: cross-validate over the grid on the training data,
# pick the best parameters by accuracy, then refit on the full training
# set and score on the held-out test set via last_fit().
mod_fit <- workflows::workflow() %>%
workflows::add_recipe(mod_recipe) %>%
workflows::add_model(mod)
best_params <- mod_fit %>%
tune::tune_grid(resamples = rsample::vfold_cv(train_df),
grid = grid) %>%
tune::select_best(metric = "accuracy")
mod_fit <- mod_fit %>%
tune::finalize_workflow(best_params) %>%
tune::last_fit(splits)
} else {
# No tuning: fit the workflow on the training portion of `splits`
# and evaluate on its test portion.
mod_fit <- workflows::workflow() %>%
workflows::add_recipe(mod_recipe) %>%
workflows::add_model(mod) %>%
tune::last_fit(splits)
}
model_fits[[model_name]] <- mod_fit
# Held-out predictions and metrics produced by last_fit().
model_preds[[model_name]] <- mod_fit %>%
tune::collect_predictions()
model_errs[[model_name]] <- mod_fit %>%
tune::collect_metrics()
model_vimps[[model_name]] <- tryCatch({
# model-specific variable importance
mod_fit %>%
workflows::extract_fit_parsnip() %>%
vip::vi()
}, error = function(e) {
# model-agnostic permutation variable importance
# (fallback for engines with no built-in importance, e.g. the kernlab SVM)
mod_fit %>%
workflows::extract_fit_parsnip() %>%
vip::vi(method = "permute", train = train_df, target = ".y",
feature_names = setdiff(colnames(train_df), ".y"),
pred_wrapper = predict, metric = "accuracy")
})
}
# Stack the per-model results into long tibbles tagged by model name.
combine_by_model <- function(results) {
  dplyr::bind_rows(results, .id = "model")
}
model_preds <- combine_by_model(model_preds)
model_errs <- combine_by_model(model_errs)
model_vimps <- combine_by_model(model_vimps)